% Analyze the data from the spiking network for the VisuoMotor Assoc. task.
% PSTH for each cell in area, 'areaname';  single PSTH's are a really bad idea.  What
% I should do is to look at ensemble activity and try to classify based on
% clusters of cells in a certain area (multi-electrode analysis).  How? I
% would need to cluster, or do pattern classification analysis (simple
% regression may work.)  If I take the group and take mean for each and use
% this as a datapoint for nearest neighbor classification, I may be able to
% classify well.  This is what we did before with Darwin VIII patterns, and
% it would work OK even if the state of the ensemble is changing every once
% in a while, as long as we take short enough windows, say 100-250 msec.?
%
%  Loads the file 'patternhist.dat', which holds one row per pattern presentation:
%	column 1 is the presentation time in msec and column 2 is the pattern number.
%	Assumes pattern numbers start at 0.
%  Assumes that defineAndPrint network was run in the current environment.
%
%  function [rs sig_cells] = analyzeVMAstats( areaname, phase_start_msec, phase_end_msec , start_time_msecs, end_time_msecs,significance_level, save_seconds)
% 
%  phase_start_msec, phase_end_msec:   the start and
%		end times (inclusive) within the trial to analyze the data.  Start counting at 0 for the first msec.
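%  start_time_msecs, end_time_msecs: start and end times of the experiments (msec).  Only pattern
%		presentations strictly after start_time_msecs and strictly before end_time_msecs are analyzed.
%  significance_level: significance level passed to ranksum (0.05 if omitted).
%
%   save_seconds = the save_Seconds parameter set in spnet.cpp (10 if omitted).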

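%  RETURNS:
%	rs: array of 0/1 flags built from ranksum, comparing spike counts of neurons during the specified task
%	phase time for significantly different medians.  rs(neuron, pattern i , pattern j) is 1 when the test is
%	significant AND the neuron's mean count for pattern i is more than twice its mean count for pattern j.
%	sig_cells: struct array; sig_cells(i).list lists the cells (numbered from 0, as in C) that
%	distinguish pattern i from ALL other patterns.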
%
function [rs sig_cells] = analyzeVMAstats( areaname, phase_start_msec, phase_end_msec , start_time_msecs, end_time_msecs,significance_level, save_seconds)
% ANALYZEVMASTATS  Per-cell pattern selectivity statistics for one area.
%
%   For every pattern presentation listed in 'patternhist.dat' whose time lies
%   strictly between start_time_msecs and end_time_msecs, counts the spikes of
%   each cell of area AREANAME inside the phase window and then compares the
%   counts between every pair of patterns with ranksum.
%
%   Inputs:
%     areaname            name of the area, resolved through getGroupInd.
%     phase_start_msec,
%     phase_end_msec      inclusive window within the trial, counted from 0.
%     start_time_msecs,
%     end_time_msecs      exclusive bounds on the presentation times analyzed.
%     significance_level  level passed to ranksum (default 0.05).
%     save_seconds        seconds of data per 'sdatNNNNNN.dat' file (default 10).
%
%   Outputs:
%     rs         logical array, rs(cell, i, j) is true when ranksum is
%                significant for patterns i vs j AND the cell's mean count for
%                pattern i is more than twice its mean count for pattern j.
%                Pairs with fewer than 2 trials in either pattern stay false.
%     sig_cells  struct array; sig_cells(i).list holds the 0-based numbers of
%                the cells that distinguish pattern i from all other patterns.
%
%   Side effects: reads the global nG, overwrites the globals S (spike raster
%   of the last file loaded) and mfr (per-pattern spike counts), opens
%   figures numbered from 100 and prints summaries to the command window.

global nG
global S
global mfr

% Defaults for the optional trailing arguments.
if nargin < 6
    significance_level = 0.05;
end
if nargin < 7
    save_seconds = 10;
end

% Column 1 is used as the presentation time, column 2 as the 0-based pattern number.
patternhist = load('patternhist.dat');
if isempty(patternhist)
    error('analyzeVMAstats:emptyPatternHist', 'patternhist.dat contains no pattern presentations.');
end
if phase_end_msec < phase_start_msec
    error('analyzeVMAstats:badPhaseWindow', 'phase_end_msec must not be smaller than phase_start_msec.');
end

% Convert the 0-based phase bounds to 1-based.
phase_start_msec = phase_start_msec + 1;
phase_end_msec = phase_end_msec + 1;
% The window is inclusive on both ends, so its length is end-start+1 msec.
window_msec = phase_end_msec - phase_start_msec + 1;

index = getGroupInd( areaname );

[first last] = getGroupNeuronIndices( nG(index).name );
first = first + 1;   % 0-based -> 1-based row of S
last = last + 1;
rows = nG(index).dimX;
cols = nG(index).dimY;
num_cells = nG(index).size;

num_patterns = max( patternhist(:,2) ) + 1;
disp(['Number of patterns: ' num2str(num_patterns)]);

count = zeros( 1, num_patterns );

% Preallocate one entry per pattern so that patterns without any analyzed
% trial still have a (cells x 0) spike-count matrix.
mfr = struct('spikes', repmat({zeros(last-first+1, 0)}, 1, num_patterns));

msec_per_file = save_seconds*1000;
filename = '';

trial_indices_to_analyze = find(patternhist(:,1)>start_time_msecs & patternhist(:,1)<end_time_msecs);

for i = trial_indices_to_analyze'

    t = patternhist( i, 1 );

    % Files are named after the second at which they end.
    file_end_sec = ceil( (t+phase_start_msec)/msec_per_file )*save_seconds;
    current_filename = ['sdat' numTOstr( file_end_sec, 6 ) '.dat'];

    % Load new filename if necessary.
    if strcmp(current_filename, filename) == 0
        filename = current_filename;
        disp(['Loading ' filename]);
        sdat = load(filename);
        max_time_msec = file_end_sec*1000;
        if isempty(sdat)
            % No spikes in this file: keep an all-zero raster of the right size.
            S = sparse( last, max_time_msec );
        else
            % Rows are neurons (column 2 of sdat), columns are msec (column 1).
            % The row count must therefore be derived from the neuron column.
            S = sparse( sdat(:,2)+1, sdat(:,1)+1, 1, max(last, max(sdat(:,2)+1)), max_time_msec );
        end
    end

    % If this pattern presentation is split across files, ignore it.
    end_filename = ['sdat' numTOstr( ceil( (t+phase_end_msec)/msec_per_file )*save_seconds, 6 ) '.dat'];
    if strcmp(end_filename, filename) == 0
        continue;
    end

    % NOTE(review): the phase bounds were already made 1-based above, so the
    % extra +1 here looks like a 1 msec shift of the window -- confirm against
    % the time convention used in patternhist.dat before changing it.
    col_start = t + 1 + phase_start_msec;
    col_end = t + 1 + phase_end_msec;
    if col_end > size(S, 2)
        % The window runs past the last msec of this file; treat it like a
        % presentation that is split across files.
        continue;
    end

    current_pattern = patternhist( i, 2 ) + 1;
    count(current_pattern) = count(current_pattern) + 1;
    c = count(current_pattern);

    % Gather spike counts for stats (one column per trial).
    mfr(current_pattern).spikes(:, c) = full(sum( S( first:last, col_start:col_end ), 2 ));

end

disp('Number of trials analyzed for each pattern:');
disp(count);

k = 0;
figure(100+k); k = k + 1;

% Keep the 2x2 layout for up to 4 patterns and grow the grid beyond that.
grid_cols = max(2, ceil(sqrt(num_patterns)));
grid_rows = max(2, ceil(num_patterns/grid_cols));

disp('Mean firing rate of cells to each pattern');
for i = 1:num_patterns

    subplot(grid_rows, grid_cols, i);

    imagesc(reshape( mean(mfr(i).spikes,2)*1000/window_msec, rows, cols ))
    colormap(gray);
    colorbar;
    title(['Mean rate (sp/sec) of cells in area ' areaname ' to pattern ' num2str(i) ]);

end

% Show significantly preferential responses to each pattern.

% Preallocated so that pattern pairs without enough trials read as
% "not significant" instead of leaving rs too small or undefined.
rs = false(num_cells, num_patterns, num_patterns);
num_significant = zeros(num_patterns, num_patterns);

for i = 1:num_patterns
    for j = 1:num_patterns
        if ( count(j) > 1 && count(i) > 1 )
            % Per-cell mean counts do not depend on n; compute them once per pair.
            mean_i = mean(mfr( i ).spikes, 2);
            mean_j = mean(mfr( j ).spikes, 2);
            for n = 1:num_cells
                [p_value significant] = ranksum( mfr( i ).spikes(n, : ) , mfr( j ).spikes(n, : ) , significance_level);
                % Require the response to pattern i to be more than double
                % the response to pattern j in addition to significance.
                rs(n,i,j) = significant && mean_i(n) * 0.5 > mean_j(n);
            end
            num_significant(i,j) = sum(rs(:,i,j));
        end
    end
    disp('.');
end

disp('number of cells which distinguish pattern "row" from pattern "col"');
disp(num_significant);

discriminators = zeros(1,num_patterns);
merged_discriminator_map = zeros(num_cells,1);
sig_cells = struct('list', cell(1, num_patterns));

for i = 1:num_patterns
    % A cell discriminates pattern i when it prefers i over every other pattern.
    temp = true(num_cells,1);
    for j = 1:num_patterns
        if i ~= j
            temp = temp & rs(:,i,j);
        end
    end

    merged_discriminator_map = merged_discriminator_map + i*temp;
    discriminators(i) = sum(temp);

    figure(100+k); k = k + 1;

    imagesc(reshape( double(temp), rows, cols ))
    colormap(gray);
    title(['Cells in area ' areaname ' that distinguish pattern ' num2str(i) ' from ALL others statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec)]);

    % find(temp) is 1-based within the area and first is 1-based globally,
    % hence -2 to obtain 0-based global cell numbers.
    sig_cells(i).list = (find(temp)+first-2)';
    disp('Significant cell numbers (start counting at 0 as in C): ');
    disp(sig_cells(i).list);
end

figure(100+k); k = k + 1;
bar(discriminators);
% This chart covers every pattern, so the title must not name a single one.
title(['Number of cells in area ' areaname ' that distinguish each pattern from ALL others statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec)]);

figure(100+k); k = k + 1;
imagesc(reshape( merged_discriminator_map, rows, cols ))
title(['Cells in area ' areaname ' that distinguish each pattern from ALL others statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec)]);

disp('Merged discriminator map (0 = none, i = cell prefers pattern i):');
disp(merged_discriminator_map);
% Dump the raw per-trial counts of every discriminating cell, one row per pattern.
for i = find(merged_discriminator_map')
   disp(['Responses for unit ' num2str(i)]);
   for j = 1:num_patterns
      disp([j mfr(j).spikes(i,:)]);
   end
end
